(ns proteins
(:require [tablecloth.api :as tc]
[fastmath.core :as math]
[fastmath.random :as random]
[tech.v3.datatype :as dtype]
[tech.v3.dataset :as dataset]
[tech.v3.tensor :as tensor]
[tech.v3.datatype.functional :as fun]
[aerial.hanami.common :as hc]
[aerial.hanami.templates :as ht]
[scicloj.kindly.v3.kind :as kind]
[scicloj.kindly.v3.api :as kindly]
[scicloj.clay.v2.api :as clay]
[libpython-clj2.python :refer [py. py.. py.-] :as py]
[scicloj.noj.v1.vis :as vis]
[scicloj.noj.v1.vis.python :as vis.python]
[libpython-clj2.require :refer [require-python]]
[util])
(:import java.lang.Math))(require-python '[builtins :as python]
'operator
'[arviz :as az]
'[arviz.style :as az.style]
'[pandas :as pd]
'[matplotlib.pyplot :as plt]
'[numpy :as np]
'[numpy.random :as np.random]
'[pymc :as pm]
'[Bio.PDB.PDBParser]
'[Bio.PDB]
'[Bio.PDB.Polypeptide]
'[pytensor]
'[pytensor.tensor :as pt]
'[math]):ok
(def protein-name1 "7ju5clean")

(def protein-name2 "AF-A0A024R7T2-F1-model_v4-clean")

(defn extract-coordinates-from-pdb
  "Parse `data/<protein-name>.pdb` with Biopython's `PDBParser` and return a
  tablecloth dataset with one row per standard amino-acid residue of the
  first model in the parsed structure.

  Columns:
  - `:id`             second element of the residue's Biopython id tuple
                      (the residue sequence number)
  - `:name`           the residue name, as returned by `get_resname`
  - `:ca-coordinates` coordinates of the residue's \"CA\" atom as a
                      `:float32` array

  Residues whose \"CA\" lookup or coordinate extraction throws are dropped."
  [protein-name]
  (let [filepath (str "data/" protein-name ".pdb")
        parser (Bio.PDB/PDBParser)
        structure (py. parser get_structure protein-name filepath)
        ;; only the first model of the structure is used
        first-model (first structure)
        standard-aa? (fn [residue]
                       (Bio.PDB.Polypeptide/is_aa (py. residue get_resname)
                                                  :standard true))
        ca-coordinates (fn [residue]
                         ;; a residue without a "CA" atom throws here; nil
                         ;; marks the row for removal below
                         (try
                           (dtype/->array :float32
                                          (py. (util/brackets residue "CA")
                                               get_coord))
                           (catch Exception _ nil)))
        residue->row (fn [residue]
                       {:id (second (py. residue get_id))
                        :name (py. residue get_resname)
                        :ca-coordinates (ca-coordinates residue)})]
    (->> first-model
         (mapcat (fn [chain]
                   (->> chain
                        (filter standard-aa?)
                        (map residue->row)
                        (filter :ca-coordinates))))
         tc/dataset)))

(-> protein-name1
    extract-coordinates-from-pdb
    ;; for readability of output:
    (tc/update-columns [:ca-coordinates]
                       (partial map vec)))

(defn center-1d
  "Subtract the mean of `xs` from every element of `xs`."
  [xs]
  (fun/- xs
         (fun/mean xs)))

(defn center-columns
  "Apply `center-1d` along axis 0 of the 2d tensor `xyzs` via
  `tensor/map-axis`. For the `[n-residues 3]` coordinate tensors built in
  `read-data` this is meant to center each coordinate separately
  (NOTE(review): confirm `map-axis` axis semantics)."
  [xyzs]
  (tensor/map-axis xyzs center-1d 0))

(defn read-data
  "Load two proteins, inner-join their residues on `:id`, and return their
  CA coordinates, raw and centered.

  `prots` is a sequence of two protein names (file stems under `data/`).
  Options:
  - `:limit` - when given, keep only the first `limit` joined rows.

  Returns a map with
  - `:coords`       coordinate tensors, one per protein
  - `:obs`          `:coords` passed through `center-columns`
  - `:obs-datasets` `:obs` converted with `util/xyz-tensor->dataset`"
  ([prots]
   (read-data prots nil))
  ([prots {:keys [limit]}]
   ;; use the caller-supplied `prots`; do not rebind them to the globals
   (let [[dataset1 dataset2] (map extract-coordinates-from-pdb prots)
         joined-dataset (cond-> (tc/inner-join dataset1 dataset2 :id)
                          limit (tc/head limit))
         ;; the right-hand dataset's columns carry the `right.` prefix
         coords (mapv (fn [colname]
                        (tensor/->tensor (joined-dataset colname)))
                      [:ca-coordinates :right.ca-coordinates])
         obs (mapv center-columns coords)
         obs-datasets (mapv util/xyz-tensor->dataset obs)]
     {:coords coords
      :obs obs
      :obs-datasets obs-datasets})))

(-> [protein-name1 protein-name2]
    (read-data {:limit 4})
    :obs-datasets)

;; Compare the datasets visually
;; Side-by-side 3d plot of the first `view-limit` centered residues of both
;; proteins (protein 1 purple, protein 2 orange).
(let [{:keys [obs-datasets]} (read-data [protein-name1 protein-name2])
      view-limit 50
      ->cljs (fn [dataset]
               (-> dataset
                   (tc/head view-limit)
                   util/prep-dataset-for-cljs))
      [prot1-dataset prot2-dataset] (mapv ->cljs obs-datasets)]
  (kind/hiccup
   ['(fn [{:keys [prot1-dataset
                  prot2-dataset]}]
       [plotly
        {:data [(-> prot1-dataset
                    (merge {:type :scatter3d
                            :mode :lines+markers
                            :opacity 1
                            :marker {:size 3
                                     :color "purple"}}))
                (-> prot2-dataset
                    (merge {:type :scatter3d
                            :mode :lines+markers
                            :opacity 1
                            :marker {:size 3
                                     :color "orange"}}))]}])
    {:prot1-dataset prot1-dataset
     :prot2-dataset prot2-dataset}]))

(defn rotate-q
  "Build a 3x3 rotation matrix (a pytensor expression) from `u`, a length-3
  pytensor vector whose entries are expected in [0, 1].

  `u` is mapped to a unit quaternion (w, x, y, z):
    x = sqrt(1-u0) sin(2 pi u1),  y = sqrt(1-u0) cos(2 pi u1),
    z = sqrt(u0)   sin(2 pi u2),  w = sqrt(u0)   cos(2 pi u2),
  so w^2 + x^2 + y^2 + z^2 = 1 (Shoemake's construction; uniform `u` gives a
  uniformly distributed rotation). The quaternion is then converted to the
  standard rotation matrix."
  [u]
  (let [two-pi (* 2 Math/PI)
        theta1 (operator/mul (util/brackets u 1) two-pi)
        theta2 (operator/mul (util/brackets u 2) two-pi)
        r1 (pt/sqrt (operator/sub 1 (util/brackets u 0)))
        r2 (pt/sqrt (util/brackets u 0))
        w (operator/mul (pt/cos theta2) r2)
        x (operator/mul (pt/sin theta1) r1)
        y (operator/mul (pt/cos theta1) r1)
        z (operator/mul (pt/sin theta2) r2)
        ;; diagonal entry: a^2 + b^2 - (c^2 + d^2)
        diag-entry (fn [a b c d]
                     (operator/sub (operator/add (pt/sqr a) (pt/sqr b))
                                   (operator/add (pt/sqr c) (pt/sqr d))))
        ;; off-diagonal entries: 2 (p q - r s) and 2 (p q + r s)
        off-minus (fn [p q r s]
                    (operator/mul 2 (operator/sub (operator/mul p q)
                                                  (operator/mul r s))))
        off-plus (fn [p q r s]
                   (operator/mul 2 (operator/add (operator/mul p q)
                                                 (operator/mul r s))))]
    (pt/stack [(pt/stack [(diag-entry w x y z) (off-minus x y w z) (off-plus x z w y)])
               (pt/stack [(off-plus x y w z) (diag-entry w y x z) (off-minus y z w x)])
               (pt/stack [(off-minus x z w y) (off-plus y z w x) (diag-entry w z x y)])])))

(defonce model
  ;; Memoized on the options map: repeated calls with an equal map return the
  ;; cached results instead of re-running the sampler.
  (memoize
   (fn [{:keys [residues-limit tune chains draws]
         :or {tune 1000 chains 1 draws 200}}]
     ;; Options:
     ;; - `:residues-limit` number of joined residues to use (nil = all)
     ;; - `:tune`   PyMC tuning steps (default 1000, PyMC's own default)
     ;; - `:chains` number of chains (default 1)
     ;; - `:draws`  posterior draws per chain (default 200)
     (let [{:keys [obs]} (read-data [protein-name1 protein-name2]
                                    {:limit residues-limit})
           structures (mapv #(tensor/transpose % [1 0]) obs)
           np-structures (mapv util/tensor2d->np-matrix structures)
           shape (-> (obs 0)
                     dtype/shape
                     reverse
                     vec)
           [space-dimension n-residues] shape]
       (py/with [pm-model (pm/Model)]
         (let [M (pm/Cauchy "M"
                            :alpha 0
                            :beta 1
                            :shape shape)
               ;; NOTE(review): subtracts the mean over all entries of M, not a
               ;; per-coordinate mean - confirm this is intended
               M0 (pm/Deterministic "M0"
                                    (operator/sub M (pt/mean M)))
               t (pm/Normal "t" :shape [space-dimension]) ; the shift
               ;; add `t` to every column of a [space-dimension n-residues]
               ;; matrix; transposing around the addition makes `t`
               ;; broadcast along the last axis
               shift (fn [m]
                       (-> m
                           pt/transpose
                           (operator/add t)
                           pt/transpose))
               u (pm/Uniform "u" :shape [space-dimension]) ; randomization of rotation
               R (pm/Deterministic "R" (rotate-q u)) ; the rotation matrix
               U (pm/HalfNormal "U"
                                :sigma 0.01 ; TODO: Consider some prior here
                                :shape [n-residues])
               M0_rotated (pm/Deterministic "M0_rotated"
                                            (pt/dot R M0))
               X1 (pm/MatrixNormal "X1"
                                   :mu M0
                                   :rowcov (np/eye space-dimension)
                                   :colcov (pt/diag U)
                                   :observed (np-structures 0))
               X2 (pm/MatrixNormal "X2"
                                   :mu (shift M0_rotated)
                                   :rowcov (np/eye space-dimension)
                                   :colcov (pt/diag U)
                                   :observed (np-structures 1))
               M0_adapted (pm/Deterministic "M0_adapted"
                                            (shift M0_rotated))
               X1_adapted (pm/Deterministic "X1_adapted"
                                            (shift (pt/dot R X1)))
               prot1_adapted (pm/Deterministic "prot1_adapted"
                                               (shift (pt/dot R (np-structures 0))))
               prior-predictive-samples (pm/sample_prior_predictive)
               idata (pm/sample :chains chains
                                :draws draws
                                :tune tune)
               posterior-predictive-samples (pm/sample_posterior_predictive
                                             idata)]
           {:structures structures
            :prior-predictive-samples prior-predictive-samples
            :posterior-predictive-samples posterior-predictive-samples
            :idata idata}))))))

nil
(model {:residues-limit 100 :tune 15})
;; => {:structures [#tech.v3.tensor<object>[3 100]
;; [[ 9.595 13.34 14.12 16.27 14.92 14.56 13.25 9.845 9.544 7.460 7.627 4.379 4.230 3.901 0.7257 -2.643 -5.457 -8.799 -12.13 -12.71 -13.19 -14.33 -14.17 -14.20 -14.55 -10.74 -9.413 -10.14 -9.008 -8.636 -7.963 -4.995 -2.462 1.236 3.753 6.887 6.023 8.970 10.52 7.208 6.448 3.674 0.9007 0.05573 -3.307 -2.918 -4.943 -3.931 -4.639 -4.793 -6.048 -8.605 -8.173 -9.640 -7.165 -5.966 -2.469 -3.889 -4.879 -1.545 0.3107 -1.723 -0.9393 2.797 2.672 1.035 3.762 6.473 5.075 4.289 7.904 8.951 6.183 5.784 3.301 0.5967 -2.265 -1.028 -1.706 1.451 3.100 4.524 5.649 4.702 4.686 6.312 7.680 4.867 1.573 -0.08527 1.285 0.1167 0.4817 0.5327 -0.6753 -1.898 -5.249 -7.753 -10.45 -13.31]
;; [-2.687 -2.544 0.2389 -1.951 -1.808 1.865 0.7629 -0.2181 2.107 5.178 8.275 8.489 12.26 12.06 9.997 11.64 9.197 9.869 8.067 4.432 2.393 -0.6111 -2.453 -5.970 -7.223 -6.831 -4.381 -1.743 -0.2131 3.546 5.844 8.123 10.41 9.954 12.62 11.73 8.043 6.102 9.159 10.46 14.21 15.58 13.37 9.810 8.371 5.007 2.464 1.743 -1.509 -1.445 -3.468 -2.656 -2.940 -6.379 -7.934 -11.50 -12.83 -14.52 -11.13 -9.580 -12.53 -12.27 -8.518 -9.075 -12.06 -9.997 -7.318 -9.942 -11.52 -8.370 -7.280 -9.887 -9.499 -7.046 -7.463 -5.255 -7.364 -7.872 -5.137 -3.198 -2.693 0.6829 0.2549 0.3359 2.527 0.5129 3.328 5.652 4.370 1.272 -0.8891 0.4399 -1.315 1.025 -0.4131 0.5129 2.258 0.7899 3.252 0.8429]
;; [ 10.45 9.779 7.300 5.007 1.393 0.5175 -2.854 -1.436 1.561 0.9625 3.280 5.188 4.946 1.154 1.277 1.590 2.179 0.4605 0.3115 1.267 -1.869 0.05551 3.395 4.670 8.245 8.909 6.375 3.817 0.5865 0.3255 -2.515 -2.061 -3.602 -3.314 -3.805 -5.686 -5.875 -7.432 -9.005 -10.52 -10.66 -8.504 -7.160 -8.333 -7.506 -5.651 -3.721 -0.1415 1.696 5.476 8.384 10.97 14.72 15.08 12.63 13.37 12.75 9.671 8.168 9.016 7.442 4.212 3.912 4.453 2.109 -0.6425 -0.2335 -0.4305 -3.533 -5.528 -5.159 -7.697 -10.22 -12.98 -15.73 -17.30 -16.17 -12.69 -10.11 -9.447 -6.053 -5.377 -1.824 1.847 4.911 7.695 9.969 11.05 12.16 10.82 8.016 4.612 1.301 -1.709 -4.991 -8.413 -8.613 -11.04 -11.63 -12.24]]
;; #tech.v3.tensor<object>[3 100]
;; [[ 10.55 10.04 13.18 12.30 12.94 13.04 12.47 9.024 7.993 4.936 3.721 0.7633 -0.8127 -1.138 -3.119 -6.924 -8.353 -11.92 -14.27 -13.15 -12.74 -12.13 -11.09 -10.00 -6.972 -4.765 -5.391 -7.623 -7.801 -9.358 -9.853 -7.826 -6.778 -3.038 -2.058 1.247 1.941 5.457 5.769 2.728 0.3743 -3.187 -4.459 -4.275 -6.934 -5.118 -5.743 -4.195 -3.401 -2.981 -3.281 -6.541 -6.974 -6.376 -3.627 -1.052 2.785 2.354 0.1603 2.746 5.504 3.338 2.592 6.337 7.242 4.666 6.177 9.745 8.662 6.576 9.728 11.46 8.648 7.459 5.473 2.140 0.4083 1.730 0.02830 2.042 3.535 3.325 4.719 4.119 3.606 6.016 6.659 4.389 1.535 0.4353 2.099 0.1283 1.063 -0.06270 -0.8277 -2.638 -6.478 -8.012 -11.52 -13.32]
;; [-3.081 -1.371 -2.272 -4.760 -3.423 -6.518 -4.281 -3.242 -6.000 -7.965 -11.38 -10.58 -13.98 -12.68 -9.547 -9.919 -6.873 -5.764 -2.777 0.4731 3.393 6.073 6.443 9.572 11.71 10.12 6.526 3.853 3.087 -0.1568 -1.636 -4.857 -7.564 -8.316 -11.67 -12.02 -8.256 -7.422 -10.98 -10.65 -13.67 -13.07 -10.29 -6.461 -4.194 -1.975 0.4552 -0.3799 2.434 1.365 2.679 2.905 0.5062 3.514 5.214 7.549 7.247 10.48 8.700 5.842 8.441 10.33 6.893 5.893 9.346 8.743 5.234 6.668 9.470 6.969 4.747 7.694 9.300 7.719 10.04 9.491 12.20 10.96 8.031 4.865 2.583 -1.144 -2.485 -2.790 -5.425 -4.180 -7.619 -10.12 -7.678 -3.995 -1.775 -1.959 0.4041 -1.083 1.705 2.442 2.187 5.587 4.380 7.764]
;; [-12.86 -9.505 -7.517 -4.742 -1.212 1.113 4.222 2.853 0.3602 1.607 0.2622 -2.110 -1.086 2.548 1.378 0.9962 -0.8328 0.1012 -0.4978 -2.189 0.2842 -2.407 -6.087 -8.047 -9.103 -6.350 -7.602 -5.963 -2.196 -0.8448 2.644 3.026 5.521 5.864 7.487 9.457 9.013 10.39 11.86 14.22 14.49 13.13 10.81 10.61 9.032 6.478 3.624 0.2112 -2.262 -5.916 -9.492 -11.50 -14.48 -16.85 -14.74 -16.33 -15.96 -13.88 -11.27 -11.18 -10.39 -7.841 -6.145 -6.254 -4.792 -2.010 -1.406 -1.178 1.264 3.347 3.779 5.566 7.667 10.97 13.31 15.27 13.12 9.715 7.880 8.681 6.001 6.797 3.465 -0.3008 -3.016 -5.754 -7.391 -9.277 -10.15 -9.788 -7.127 -3.842 -0.9358 2.436 4.969 8.267 8.060 8.939 10.03 9.478]]],
;; :prior-predictive-samples Inference data with groups:
;; > prior
;; > prior_predictive
;; > observed_data,
;; :posterior-predictive-samples Inference data with groups:
;; > posterior_predictive
;; > observed_data,
;; :idata Inference data with groups:
;; > posterior
;; > sample_stats
;; > observed_data}
(defn show-results
  "Render a plotly 3d scatter comparing the posterior draws of
  `prot1_adapted` with the second structure in `(:structures results)`.

  `results` is a map as returned by `model`. Options:
  - `:view-limit` - number of leading residues plotted per structure.

  Each posterior draw becomes its own trace (opacity 0.1) coloured by the
  chain it came from; colours cycle through blue/yellow/red/green, so any
  number of chains is supported. Protein 2 is drawn opaque in orange."
  [results {:keys [view-limit]}]
  (let [tensor->cljs (fn [tensor]
                       (-> tensor
                           (tensor/transpose [1 0])
                           util/xyz-tensor->dataset
                           (tc/head view-limit)
                           util/prep-dataset-for-cljs))
        prot1-adapted (-> results
                          :idata
                          (py.- posterior)
                          (py.- prot1_adapted))
        ;; posterior arrays are indexed by (chain, draw, ...)
        shape (np/shape prot1-adapted)
        n-chains (first shape)
        n-samples (second shape)
        ;; one dataset per (chain, draw), chain-major order
        prot1-adapted-datasets (->> (tensor/slice (util/py-array->clj prot1-adapted) 1)
                                    (mapcat (fn [chain-tensor]
                                              (map tensor->cljs
                                                   (tensor/slice chain-tensor 1))))
                                    vec)
        ;; chain index of each dataset above, in the same order
        prot1-chain-idx (->> (range n-chains)
                             (mapcat #(repeat n-samples %))
                             vec)]
    (kind/hiccup
     ['(fn [{:keys [prot1-adapted-datasets
                    prot1-chain-idx
                    prot2-dataset]}]
         (let [chain-colors ["blue" "yellow" "red" "green"]]
           [plotly
            {:data (->> (map (fn [dataset chain-idx]
                               ;; one colour per trace, chosen by its chain
                               (merge dataset
                                      {:type :scatter3d
                                       :mode :lines+markers
                                       :opacity 0.1
                                       :marker {:size 3
                                                :color (nth chain-colors
                                                            (mod chain-idx
                                                                 (count chain-colors)))}}))
                             prot1-adapted-datasets
                             prot1-chain-idx)
                        (cons (merge prot2-dataset
                                     {:type :scatter3d
                                      :mode :lines+markers
                                      :opacity 1
                                      :marker {:size 3
                                               :color "orange"}}))
                        vec)}]))
      {:prot1-adapted-datasets prot1-adapted-datasets
       :prot1-chain-idx prot1-chain-idx
       :prot2-dataset (-> results
                          :structures
                          second
                          tensor->cljs)}])))

(-> {:residues-limit 100 :tune 200}
    model
    (show-results {:view-limit 50}))

(-> {:residues-limit 100 :tune 50}
    model
    (show-results {:view-limit 50}))

(-> {:residues-limit 100 :tune 15}
    model
    (show-results {:view-limit 50}))

(-> {:residues-limit 100 :tune 5}
    model
    (show-results {:view-limit 50}))

:bye:bye